classdef GPOPSFunctions < handle
    % Collection of the continuous/endpoint building blocks handed to
    % GPOPS. The GPOPS "time" variable is (a possibly log-transformed and
    % standardized version of) the skill level w, see transform_w.
    
    properties
        prod_fun % ProductionFunction instance
        util_fun % UtilityFunction instance
        skill_dist % SkillDistribution instance
        part_dist % ParticipationCostDistribution instance
        
        r % inequality aversion
        q_B % price of robots
        q_E % price of equipment capital
        q_S % price of structures
        revenue_requirement % exog. revenue requirement
        phi_lb % lower bound for integration over part. costs
    end
    
    properties (SetAccess = protected)
        num_nodes % number of Gauss-Legendre nodes
        nodes % computed integration nodes
        weights % computed integration weights
    end
    
    methods
        
        function self = GPOPSFunctions(prod_fun, util_fun, skill_dist, ...
                part_dist, par)
            % Constructor.
            %
            % Inputs:
            %   prod_fun, util_fun, skill_dist, part_dist - model objects
            %   par - struct with fields r, q_B, q_E, q_S,
            %         revenue_requirement, num_nodes, phi_lb
            self.prod_fun = prod_fun;
            self.util_fun = util_fun;
            self.skill_dist = skill_dist;
            self.part_dist = part_dist;
            self.r = par.r;
            self.q_B = par.q_B;
            self.q_E = par.q_E;
            self.q_S = par.q_S;
            self.revenue_requirement = par.revenue_requirement;
            self.phi_lb = par.phi_lb;
            
            % quadrature nodes/weights on the reference interval [-1,1]
            self.compute_nodes(par.num_nodes);
        end
        
        
        function self = compute_nodes(self, num_nodes)
            % (Re)compute the Gauss-Legendre nodes and weights on [-1,1]
            % used for integrating over participation costs.
            %
            % Inputs:
            %   num_nodes - number of quadrature nodes
            % Output:
            %   self - the (handle) object, returned for convenience
            self.num_nodes = num_nodes;
            [self.nodes, self.weights] = ...
                tools.Integration.lgwt(self.num_nodes,-1,1);
        end
        
        function out = labor_market_clearing_M(self, var)
            % Pointwise term l*w*h_M, times the Jacobian var.wscale, for the
            % sector-M labor market clearing condition.
            out =  (var.l_w(:).*var.w(:).*var.h_M_w(:)) .* var.wscale;
        end
        
        function out = labor_market_clearing_R(self, var)
            % Pointwise term l*w*h_R, times var.wscale, for the sector-R
            % labor market clearing condition.
            out =  (var.l_w(:).*var.w(:).*var.h_R_w(:)) .* var.wscale;
        end
        
        function out = labor_market_clearing_C(self, var)
            % Pointwise term l*w*h_C, times var.wscale, for the sector-C
            % labor market clearing condition.
            out =  (var.l_w(:).*var.w(:).*var.h_C_w(:)) .* var.wscale;
        end
        
        function out = resource_constraint(self,var)
            % Pointwise integrand (times var.wscale) of the resource
            % constraint: net resources from participants, minus transfers
            % to non-participants, plus capital income, minus capital
            % expenditure and the exogenous revenue requirement. Aggregate
            % terms are spread over skills by weighting with f_w.
            rc_participants = (var.l_w(:).*var.w(:) - var.c_w(:)).*var.h_w(:);
            rc_non_participants = -var.transfer.*(var.f_w(:) - var.h_w(:));
            
            rc_capital_expenditure =  (var.K_B * self.q_B + ...
                var.K_E * self.q_E + ...
                var.K_S * self.q_S).*var.f_w(:);
            
            rc_exog_revenue = self.revenue_requirement.*var.f_w(:);
            
            rc_capital_income = (var.K_B_M .* var.Y_B_M + ...
                var.K_B_R .* var.Y_B_R + ...
                var.K_B_C .* var.Y_B_C + ...
                var.K_E .* var.Y_E + ...
                var.K_S .* var.Y_S).*var.f_w(:);
            
            out = (rc_participants ...
                + rc_non_participants ...
                + rc_capital_income ...
                - rc_capital_expenditure ...
                - rc_exog_revenue).* ...
                var.wscale;
        end
        
        function out = dynamics(self, var)
            % State dynamics dV/dwtilde: util_fun.Udot evaluated at (l, w),
            % times the Jacobian var.wscale.
            out = self.util_fun.Udot(var.l_w, var.w).* ...
                var.wscale;
        end
        
        function out = welfare(self, var)
            % Pointwise social welfare integrand: participants of all three
            % sectors plus non-participants.
            out = self.welfare_participants_M(var) ...
                + self.welfare_participants_R(var) ...
                + self.welfare_participants_C(var) ...
                + self.welfare_non_participants(var);
        end
        
        function out = welfare_non_participants(self, var)
            % Welfare of non-participants: Psi(V_b) times the mass of
            % non-participants (f_w - h_w), times var.wscale.
            out = (self.Psi(var.V_b) .* (var.f_w(:) - var.h_w(:))).*var.wscale;
        end
        
        function out = consumption_M(self, var)
            % Consumption times the sector-M participant density.
            out = var.c_w.*var.h_M_w;
        end
        
        function out = consumption_R(self, var)
            % Consumption times the sector-R participant density.
            out = var.c_w.*var.h_R_w;
        end
        
        function out = consumption_C(self, var)
            % Consumption times the sector-C participant density.
            out = var.c_w.*var.h_C_w;
        end
        
        function out = welfare_participants_M(self, var)
            % Welfare of sector-M participants at each skill node
            % (integrated over participation costs, see
            % welfare_participants_sector).
            out = self.welfare_participants_sector(var, ...
                @(phi) self.part_dist.g_M(phi), var.f_M_w);
        end
        
        
        function out = welfare_participants_R(self, var)
            % Welfare of sector-R participants at each skill node.
            out = self.welfare_participants_sector(var, ...
                @(phi) self.part_dist.g_R(phi), var.f_R_w);
        end
        
        
        function out = welfare_participants_C(self, var)
            % Welfare of sector-C participants at each skill node.
            out = self.welfare_participants_sector(var, ...
                @(phi) self.part_dist.g_C(phi), var.f_C_w);
        end
        
        
        function out = Psi(self, V)
            % Social welfare function over indirect utilities V (CRRA form
            % with inequality aversion self.r):
            %   Psi(V) = (V.^(1-r) - 1) ./ (1-r)   for r ~= 1
            %   Psi(V) = log(V)                    for r == 1
            % The r == 1 case is the limit of the general formula, which
            % would otherwise evaluate to 0/0 = NaN.
            if self.r == 1
                out = log(V);
            else
                out = (V.^(1-self.r)-1)./(1-self.r);
            end
        end
        
        
        function out = compute_var_from_input(self, input)
            % Build the struct of all quantities needed by the other
            % methods from a GPOPS input struct. Computed once per
            % iteration; results are used by many functions.
            
            % scaling and log-transformation for w
            [out.w, out.wtilde, out.wscale] = self.transform_w(input);
            
            inpscale = input.auxdata.settings.inpscale;
            
            % state and control
            out.V_w = input.phase.state(:,1) .* inpscale.u;
            out.l_w = input.phase.control(:,1) .* inpscale.l;
            out.c_w = self.util_fun.c(out.V_w, out.l_w);
            
            % parameters and derived variables
            parameters = input.phase.parameter(1,:);
            
            out.L_M =   inpscale.L_M .*   parameters(1);
            out.L_R =   inpscale.L_R .*   parameters(2);
            out.L_C =   inpscale.L_C .*   parameters(3);
            out.K_B_M = inpscale.K_B_M .* parameters(4);
            out.K_B_R = inpscale.K_B_R .* parameters(5);
            out.K_B_C = inpscale.K_B_C .* parameters(6);
            out.K_E =   inpscale.K_E .*   parameters(7);
            out.K_S =   inpscale.K_S .*   parameters(8);
            
            out.transfer = inpscale.transfer.* ...
                parameters(9);
            
            out.K_B = out.K_B_M + out.K_B_R + out.K_B_C;
            out.V_b = self.util_fun.U(out.transfer,0);
            
            % factor prices
            out.Y_M = self.prod_fun.Y_M(out);
            out.Y_R = self.prod_fun.Y_R(out);
            out.Y_C = self.prod_fun.Y_C(out);
            out.Y_B_M = self.prod_fun.Y_B_M(out);
            out.Y_B_R = self.prod_fun.Y_B_R(out);
            out.Y_B_C = self.prod_fun.Y_B_C(out);
            out.Y_E = self.prod_fun.Y_E(out);
            out.Y_S = self.prod_fun.Y_S(out);
            
            % densities
            out.f_M_w = self.skill_dist.f_M(out.w,out.Y_M,out.Y_R,out.Y_C);
            out.f_R_w = self.skill_dist.f_R(out.w,out.Y_M,out.Y_R,out.Y_C);
            out.f_C_w = self.skill_dist.f_C(out.w,out.Y_M,out.Y_R,out.Y_C);
            out.f_w = out.f_M_w + out.f_R_w + out.f_C_w;
            
            % participation-cost threshold
            out.phi_tilde_w = out.V_w - out.transfer;
            
            if input.auxdata.settings.fix_participation
                
                % participation shares fixed exogenously per sector
                pi_M = input.auxdata.settings.pi_M;
                pi_R = input.auxdata.settings.pi_R;
                pi_C = input.auxdata.settings.pi_C;
                
                out.G_phi_tilde_M_w = pi_M .*ones(size(out.w));
                out.G_phi_tilde_R_w = pi_R .*ones(size(out.w));
                out.G_phi_tilde_C_w = pi_C .*ones(size(out.w));
                
            else
                
                % participation-cost CDFs evaluated at the threshold
                out.G_phi_tilde_M_w = ...
                    self.part_dist.G_M(out.phi_tilde_w);
                
                out.G_phi_tilde_R_w = ...
                    self.part_dist.G_R(out.phi_tilde_w);
                
                out.G_phi_tilde_C_w = ...
                    self.part_dist.G_C(out.phi_tilde_w);
            end
            
            out.h_M_w = out.G_phi_tilde_M_w .* out.f_M_w;
            out.h_R_w = out.G_phi_tilde_R_w .* out.f_R_w;
            out.h_C_w = out.G_phi_tilde_C_w .* out.f_C_w;
            out.h_w = out.h_M_w + out.h_R_w + out.h_C_w;
            
        end
        
        
        function [w, wtilde, wscale] = transform_w(self, input)
            % Map the GPOPS time grid wtilde to skills w.
            %
            % Outputs:
            %   w      - skill levels
            %   wtilde - input.phase.time (unchanged)
            %   wscale - Jacobian dw/dwtilde used to rescale integrands:
            %            sigma*w if log_w and standardize,
            %            w       if log_w only,
            %            1       otherwise
            settings = input.auxdata.settings;
            wtilde = input.phase.time;
            
            if settings.log_w==1
                if settings.standardize == 1
                    mu = input.auxdata.mu_fit;
                    sigma = input.auxdata.sigma_fit;
                    w = exp(wtilde.*sigma + mu);
                    wscale = sigma.*w;
                else
                    w = exp(wtilde);
                    wscale = w;
                end
            else
                w = wtilde;
                wscale = 1;
            end
        end
        
        
    end
    
    methods (Access = protected)
        
        function out = welfare_participants_sector(self, var, g_sector, f_sector_w)
            % Welfare of the participants of one sector at each skill node.
            %
            % Numerically integrates
            %   Psi(max(V_w - phi, transfer)) * g_sector(phi)
            % over participation costs phi in [self.phi_lb, phi_tilde_w]
            % with Gauss-Legendre quadrature. The quadrature nodes live on
            % [-1,1] and are mapped onto [self.phi_lb, phi_tilde_w] with
            % tools.ChangeOfVariables.change_variable_minone_one; the
            % integrand is multiplied by that map's derivative.
            %
            % Inputs:
            %   var        - struct built by compute_var_from_input
            %   g_sector   - function handle, participation-cost density of
            %                the sector
            %   f_sector_w - sector skill density at the skill nodes
            % Output:
            %   column vector: integral .* f_sector_w(:) .* var.wscale
            
            n_nodes = length(self.nodes);
            ref_nodes = self.nodes(:)';
            
            % upper integration bound, clipped so it is never below the
            % lower bound self.phi_lb
            phi_tilde = var.phi_tilde_w;
            phi_tilde(phi_tilde < self.phi_lb) = self.phi_lb;
            
            % one row per skill node, one column per quadrature node
            u_mat = repmat(var.V_w(:),1,n_nodes);
            lb_mat = self.phi_lb.*ones(size(u_mat));
            ub_mat = repmat(phi_tilde(:),1,n_nodes);
            nodes_mat = repmat(ref_nodes,length(phi_tilde(:)),1);
            
            % map reference nodes to participation costs once and reuse
            phi_mat = tools.ChangeOfVariables.change_variable_minone_one( ...
                nodes_mat,lb_mat,ub_mat);
            dphi_mat = tools.ChangeOfVariables.change_variable_minone_one_deriv( ...
                nodes_mat,lb_mat,ub_mat);
            
            % utility net of participation cost, floored at the transfer
            integrand = self.Psi(max(u_mat - phi_mat,var.transfer)) ...
                .*g_sector(phi_mat).*dphi_mat;
            
            out = ((integrand * self.weights(:)).* ...
                f_sector_w(:)).*var.wscale;
        end
        
    end
end